Llama3 hybrid implementation using submeshes #18777
base: main
Conversation
Clean 👌
To do:
- Add at least one CI test that exercises DP. I suggest adding a demo to the t3k tests (a possible parametrization is sketched after the snippet below).
```python
if is_ci_env and num_devices == 8 and data_parallel > 1 and not ("3.2-1B" in llama_dir or "3.1-8B" in llama_dir):
    pytest.skip("CI runs only hybrid Llama3 1b and 8b on T3K")
```
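A minimal sketch of how a t3k demo test could be parametrized over DP, under assumptions (test name, ids, and values are illustrative, not the PR's actual test):

```python
import pytest

# Hypothetical DP sweep for a t3k (8-device) hybrid demo test: data_parallel=4
# means 4 replicas, each tensor-parallel over 8 // 4 = 2 devices.
@pytest.mark.parametrize("data_parallel", (1, 2, 4), ids=("tp8", "dp2_tp4", "dp4_tp2"))
def test_llama_demo_hybrid(data_parallel):
    num_devices = 8  # T3K
    tensor_parallel = num_devices // data_parallel
    assert data_parallel * tensor_parallel == num_devices
    # The real test would open a 1x8 mesh, split it into `data_parallel`
    # submeshes, and run the existing Llama3 demo on each replica.
```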
What about 3B?
I wanted to avoid burdening CI with additional tests. 1B and 8B seemed sufficient for perf regression checks, being the smallest and largest variants of the smaller Llama3 models. Should we add 3B anyway?
```python
return data_parallel, mesh_device.create_submeshes(ttnn.MeshShape(1, num_devices // data_parallel))
```
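To illustrate the submesh split, a sketch under assumptions (mesh opening and shapes are illustrative; only `create_submeshes` is taken from the diff above):

```python
import ttnn

num_devices = 8     # e.g. a T3K system
data_parallel = 4   # 4 independent model replicas

# Open the full 1x8 mesh, then carve it into data_parallel submeshes of
# num_devices // data_parallel devices each; each replica stays tensor-parallel
# within its own submesh.
mesh_device = ttnn.open_mesh_device(mesh_shape=ttnn.MeshShape(1, num_devices))
submeshes = mesh_device.create_submeshes(ttnn.MeshShape(1, num_devices // data_parallel))
assert len(submeshes) == data_parallel
```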
```python
def allocate_kv_cache(kv_cache_shape, dtype, num_layers, mesh_device):
```
TODO (@ipotkonjak-tt and/or @skhorasganiTT) Modify KV creation in vLLM to use this function and test with DP
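For context, a minimal sketch of what a per-submesh KV cache allocation could look like (the signature comes from the diff above; the body is an assumption, not necessarily how this PR implements it):

```python
import torch
import ttnn

def allocate_kv_cache(kv_cache_shape, dtype, num_layers, mesh_device):
    # Assumption: one zero-initialized (K, V) pair per layer, replicated across
    # the devices of the given (sub)mesh; vLLM would call this once per submesh.
    kv_cache = []
    for _ in range(num_layers):
        layer_kv = [
            ttnn.from_torch(
                torch.zeros(kv_cache_shape),
                dtype=dtype,
                layout=ttnn.TILE_LAYOUT,
                device=mesh_device,
                mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
            )
            for _ in range(2)  # K and V
        ]
        kv_cache.append(layer_kv)
    return kv_cache
```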
Problem description
The Llama3 models currently lack support for data / hybrid parallelism.
What's changed
Adds hybrid (tensor + data) parallelism to the Llama3 code base using the concept of submeshes. The implementation lives mainly at the LlamaGenerator level: the MeshDevice is partitioned into submeshes, each subset of devices runs an independent copy of the model, and each copy remains tensor-parallel within its submesh. A sketch of the resulting generator-level structure is given below.
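A minimal sketch of that structure, under assumptions (class and method names other than `create_submeshes` are illustrative, not the PR's actual LlamaGenerator API):

```python
import ttnn

class HybridLlamaGenerator:
    """Illustrative generator holding one tensor-parallel model replica per submesh."""

    def __init__(self, mesh_device, num_devices, data_parallel, model_factory):
        # One submesh (and one model replica) per data-parallel group.
        self.submeshes = mesh_device.create_submeshes(
            ttnn.MeshShape(1, num_devices // data_parallel)
        )
        # model_factory(submesh) builds the usual tensor-parallel model on a submesh.
        self.models = [model_factory(submesh) for submesh in self.submeshes]

    def generate(self, batches):
        # Assumption: the global batch is split into len(self.models) chunks and
        # each chunk is served by an independent replica.
        assert len(batches) == len(self.models)
        return [model.generate(batch) for model, batch in zip(self.models, batches)]
```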
Checklist